#    pythonequations is a collection of equations expressed as Python classes
#    Copyright (C) 2008 James R. Phillips
#    2548 Vera Cruz Drive
#    Birmingham, AL 35235 USA
#    email: zunzun@zunzun.com
#
#    License: BSD-style (see LICENSE.txt in main source directory)
#    Version info: $Id: NeuralNetwork.py 313 2011-07-11 08:45:13Z zunzun.com $

# NOTE(review): the rest of this file, from the opening triple quote on the
# next line to the closing triple quote at the end, is one string literal.
# The imports and the NeuralNetwork3D class inside it are therefore disabled:
# importing this module executes only this `pass` and a discarded string.
# Remove the enclosing triple quotes to re-enable the neural-network equation.
pass
'''
import pythonequations, pythonequations.EquationBaseClasses
import pythonequations.ExtraCodeForEquationBaseClasses as ex
import numpy, random
numpy.seterr(all = 'raise') # numpy raises warnings, convert to exceptions to trap them


class NeuralNetwork3D(pythonequations.EquationBaseClasses.Equation3D):
    """Two-input "equation" whose value is computed by a PyBrain network.

    The network is read from ``self.ann`` (None by default; assigned
    elsewhere).  The coefficient array handed to EvaluateCachedData() is not
    used.  Independent inputs are mapped into [0.05, 0.95] using the data
    ranges recorded by Initialize(), and the network output is mapped back
    from [0.05, 0.95] onto the dependent-data range.
    """
    try: # silently trap exceptions at package import time, the Initialize() method will yield an exception at run-time - user may not need neural networks
        import pybrain
        bestHiddenActivationFunction = pybrain.structure.SigmoidLayer
        bestOutputActivationFunction = pybrain.structure.SigmoidLayer
    except Exception:
        pass
    RequiresAutoGeneratedGrowthAndDecayForms = False
    RequiresAutoGeneratedOffsetForm = False
    RequiresAutoGeneratedReciprocalForm = False
    RequiresAutoGeneratedInverseForms = False
    _name = "Neural Network"
    _HTML = "y = Neural Network 3D"
    neuralNetworkFlag = 1
    testingForActivationFunctionsFlag = False
    function_cpp_code = ';' # unused
    hiddenNodeCountLayer_1 = 2
    hiddenNodeCountLayer_2 = 2
    max_iterations = 1000
    standardOneOrBypassConnectionsZeroFlag = 0
    biasFlag = 1
    ann = None
    fittingTarget = 'SSQABS' # see the Equation base class for a list of fitting targets.  Currently the Neural Networks have only SSQABS via MSE.
    printExampleOutput = False


    @staticmethod
    def _scaleToBand(values, dataMin, dataMax):
        """Map values linearly from [dataMin, dataMax] onto [0.05, 0.95].

        This range of values works with all the activation functions.
        Division by zero when dataMax == dataMin is not guarded here; with
        numpy.seterr(all='raise') in effect it raises for numpy operands.
        """
        scaled = (values - dataMin) / (dataMax - dataMin) # zero to one
        return (scaled * 0.9) + 0.05 # 0.05 to 0.95


    def CreateCacheGenerationList(self):
        """Request cached X and Y independent-variable arrays (cache[0], cache[1])."""
        self.CacheGenerationList = []
        self.CacheGenerationList.append([ex.CG_X(NameOrValueFlag=1), []])
        self.CacheGenerationList.append([ex.CG_Y(NameOrValueFlag=1), []])


    def Initialize(self):
        """Validate settings and prepare scaled copies of the fitting data.

        Raises Exception if PyBrain cannot be imported or if the first hidden
        layer has no nodes.  The data-dependent preparation (data ranges,
        scaled caches, reduced cache) is best-effort: any failure there is
        ignored and leaves those attributes unset or partially set.  Finally
        both numpy's and Python's random generators are seeded with 3 so
        runs are repeatable.
        """
        try:
            import pybrain
        except Exception:
            raise Exception('Unable to import the PyBrain neural network library.  Please install PyBrain from http://pybrain.org/')

        if self.hiddenNodeCountLayer_1 < 1:
            raise Exception('The first hidden layer must have at least one node.')

        self.CreateCacheGenerationList()

        try:
            self.dataMax_Dep = max(self.DependentDataArray)
            self.dataMin_Dep = min(self.DependentDataArray)

            # manually scale data from 0.05 to 0.95, this range of values will work with all activation functions
            self.FindOrCreateCache()
            self.dataMax_Indep0 = max(self.cache[0])
            self.dataMin_Indep0 = min(self.cache[0])
            self.dataMax_Indep1 = max(self.cache[1])
            self.dataMin_Indep1 = min(self.cache[1])

            self.scaled_cache = numpy.empty_like(self.cache)
            self.scaled_cache[0] = self._scaleToBand(self.cache[0], self.dataMin_Indep0, self.dataMax_Indep0)
            self.scaled_cache[1] = self._scaleToBand(self.cache[1], self.dataMin_Indep1, self.dataMax_Indep1)
            self.scaled_DependentDataArray = self._scaleToBand(numpy.array(self.DependentDataArray), self.dataMin_Dep, self.dataMax_Dep)

            # make sure we are not over maximum possible number of points
            # FindOrCreateReducedDataCache() will usually add one to this to be safe
            requestedPoints = 6 * (self.hiddenNodeCountLayer_1 + self.hiddenNodeCountLayer_2) * self.dimensionality # seems to work OK
            self.numberOfDecimatedRawDataPoints = min(requestedPoints, len(self.DependentDataArray))

            self.FindOrCreateReducedDataCache()
            self.scaled_reducedCache = numpy.empty_like(self.reducedCache)
            self.scaled_reducedCache[0] = self._scaleToBand(self.reducedCache[0], self.dataMin_Indep0, self.dataMax_Indep0)
            self.scaled_reducedCache[1] = self._scaleToBand(self.reducedCache[1], self.dataMin_Indep1, self.dataMax_Indep1)
            self.scaled_reducedDependentData = self._scaleToBand(numpy.array(self.reducedDependentData), self.dataMin_Dep, self.dataMax_Dep)
        except Exception:
            # Deliberately best-effort.  Failures include a constant variable
            # (zero range -> division error under numpy.seterr(all='raise'));
            # presumably also data that is not available yet when this runs.
            pass

        numpy.random.seed(3)
        random.seed(3)


    def EvaluateCachedData(self, coeff, _id):
        """Evaluate the network at every point of (_id[0], _id[1]).

        coeff is ignored; the weights live in self.ann.  Returns a float
        numpy array, one value per point, in dependent-data units.
        """
        scaled_data0 = self._scaleToBand(_id[0], self.dataMin_Indep0, self.dataMax_Indep0)
        scaled_data1 = self._scaleToBand(_id[1], self.dataMin_Indep1, self.dataMax_Indep1)

        retArray = numpy.empty(len(scaled_data0))
        for i, (value0, value1) in enumerate(zip(scaled_data0, scaled_data1)):
            # assumes activate() returns a one-element output vector (one
            # output node, as with PyBrain networks); take the element
            # explicitly instead of relying on array-to-scalar conversion
            retArray[i] = self.ann.activate([value0, value1])[0]

        # rescale output data
        retArray = (retArray - 0.05) / 0.9 # rescale zero to one
        retArray = retArray * (self.dataMax_Dep - self.dataMin_Dep) + self.dataMin_Dep

        return retArray


    def CalculateFittingTarget(self, in_coeffArray):
        """Return the sum of squared absolute errors over the cached data.

        Raises Exception for any fitting target other than SSQABS.  Returns
        1.0E300 when the sum is not finite or overflows.
        """
        if self.fittingTarget != "SSQABS":
            raise Exception('CalculateFittingTarget() only defined for SSQABS')

        error = self.EvaluateCachedData(in_coeffArray, self.cache) - self.DependentDataArray
        try:
            val = numpy.sum(error * error)
        except FloatingPointError:
            # numpy.seterr(all='raise') turns overflow into an exception
            # instead of inf, so the isfinite() check alone cannot catch it
            return 1.0E300
        if numpy.isfinite(val):
            return val
        return 1.0E300


    def SpecificCodeCPP(self):
        """Return C++ statements that evaluate the network at (x_in, y_in).

        The generated code declares one double per hidden node plus
        ``output``, scales the inputs exactly as EvaluateCachedData() does,
        accumulates every weighted connection found in self.ann, applies the
        activation functions, unscales the result and assigns it to ``temp``.
        """
        connections = []
        for key in self.ann.connections.keys():
            connections.extend(self.ann.connections[key])

        def connectionsBetween(inName, outName):
            # connections from module inName to module outName, in the original order
            return [c for c in connections if c.inmod.name == inName and c.outmod.name == outName]

        nodeCount1 = self.hiddenNodeCountLayer_1
        nodeCount2 = self.hiddenNodeCountLayer_2
        code = []

        for i in range(nodeCount1):
            code.append('\tdouble hidden_layer1_node%d = 0.0;\n' % i)
        for i in range(nodeCount2):
            code.append('\tdouble hidden_layer2_node%d = 0.0;\n' % i)
        code.append('\tdouble output = 0.0;\n')

        code.append('\n')

        # scale the input from 0.05 to 0.95
        code.append('\tx_in = (x_in - %-.16E) / (%-.16E - %-.16E);\n' % (self.dataMin_Indep0, self.dataMax_Indep0, self.dataMin_Indep0)) # zero to one
        code.append('\tx_in = (x_in * 0.9) + 0.05;\n')
        code.append('\ty_in = (y_in - %-.16E) / (%-.16E - %-.16E);\n' % (self.dataMin_Indep1, self.dataMax_Indep1, self.dataMin_Indep1)) # zero to one
        code.append('\ty_in = (y_in * 0.9) + 0.05;\n')

        code.append('\n') #########################

        # inputs to first hidden layer; the weight from input j to node i is
        # params[i * indim + j] (indim == 2 here)
        for connection in connectionsBetween('in', 'hidden0'):
            for i in range(nodeCount1):
                code.append('\thidden_layer1_node%d += x_in * %-.16E;\n' % (i, connection.params[i*2]))
                code.append('\thidden_layer1_node%d += y_in * %-.16E;\n' % (i, connection.params[i*2+1]))

        # bias to first hidden layer
        for connection in connectionsBetween('bias', 'hidden0'):
            for i in range(nodeCount1):
                code.append('\thidden_layer1_node%d += %-.16E;\n' % (i, connection.params[i]))

        # activation functions for layer 1
        for i in range(nodeCount1):
            code.append(ex.GetCppCodeForActivationFunction(self.bestHiddenActivationFunction.__name__, 'hidden_layer1_node%d' % i))

        code.append('\n') #########################

        # inputs to second hidden layer (bypass connections)
        for connection in connectionsBetween('in', 'hidden1'):
            for i in range(nodeCount2):
                code.append('\thidden_layer2_node%d += x_in * %-.16E;\n' % (i, connection.params[i*2]))
                code.append('\thidden_layer2_node%d += y_in * %-.16E;\n' % (i, connection.params[i*2+1]))

        # first hidden layer to second hidden layer: same params[i * indim + j]
        # layout as above, and here indim is the size of the FIRST hidden layer
        for connection in connectionsBetween('hidden0', 'hidden1'):
            for i in range(nodeCount2):
                for j in range(nodeCount1):
                    code.append('\thidden_layer2_node%d += hidden_layer1_node%d * %-.16E;\n' % (i, j, connection.params[i*nodeCount1 + j]))

        # bias to second hidden layer
        for connection in connectionsBetween('bias', 'hidden1'):
            for i in range(nodeCount2):
                code.append('\thidden_layer2_node%d += %-.16E;\n' % (i, connection.params[i]))

        # activation functions for layer 2
        for i in range(nodeCount2):
            code.append(ex.GetCppCodeForActivationFunction(self.bestHiddenActivationFunction.__name__, 'hidden_layer2_node%d' % i))

        code.append('\n') #########################

        # inputs to output (bypass connections)
        for connection in connectionsBetween('in', 'out'):
            code.append('\toutput += x_in * %-.16E;\n' % (connection.params[0],))
            code.append('\toutput += y_in * %-.16E;\n' % (connection.params[1],))

        # first hidden layer to output
        for connection in connectionsBetween('hidden0', 'out'):
            for i in range(nodeCount1):
                code.append('\toutput += hidden_layer1_node%d * %-.16E;\n' % (i, connection.params[i]))

        # second hidden layer to output: one weight per SECOND-layer node
        for connection in connectionsBetween('hidden1', 'out'):
            for i in range(nodeCount2):
                code.append('\toutput += hidden_layer2_node%d * %-.16E;\n' % (i, connection.params[i]))

        # bias to output
        for connection in connectionsBetween('bias', 'out'):
            code.append('\toutput += %-.16E;\n' % (connection.params[0],))

        # activation functions for output
        code.append(ex.GetCppCodeForActivationFunction(self.bestOutputActivationFunction.__name__, 'output'))

        code.append('\n') #########################

        # scale the output
        code.append('\toutput = (output - 0.05) / 0.9;\n') # rescale zero to one
        code.append('\toutput = output * (%-.16E - %-.16E) + %-.16E;\n' % (self.dataMax_Dep, self.dataMin_Dep, self.dataMin_Dep))

        code.append('\n')

        code.append('\ttemp = output;\n')

        code.append('\n')

        return ''.join(code)
'''